{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "54add068",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "PreTrainedTokenizer(name_or_path='bert-base-chinese', vocab_size=21128, model_max_len=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Chapter 9 / load the tokenizer\n",
    "from transformers import BertTokenizer\n",
    "\n",
    "# Load the tokenizer published under the 'bert-base-chinese' checkpoint name.\n",
    "# Later cells use `token` to encode sentence pairs and to decode ids back to text.\n",
    "token = BertTokenizer.from_pretrained('bert-base-chinese')\n",
    "\n",
    "# Bare last expression so the notebook displays the tokenizer's repr.\n",
    "token"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "68ad14f8",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input_ids torch.Size([2, 18])\n",
      "token_type_ids torch.Size([2, 18])\n",
      "length torch.Size([2])\n",
      "attention_mask torch.Size([2, 18])\n",
      "[CLS] 不 是 一 切 大 树 ， [SEP] 都 被 风 暴 折 断 。 [SEP] [PAD]\n"
     ]
    }
   ],
   "source": [
    "# Chapter 9 / trial-encode sentence pairs\n",
    "# Encode two (first, second) sentence pairs in a single call. With\n",
    "# truncation=True and padding='max_length', every pair comes back with\n",
    "# exactly max_length token positions, returned as PyTorch tensors;\n",
    "# return_length=True adds a per-pair 'length' entry.\n",
    "out = token.batch_encode_plus(\n",
    "    batch_text_or_text_pairs=[('不是一切大树，', '都被风暴折断。'),\n",
    "                              ('不是一切种子，', '都找不到生根的土壤。')],\n",
    "    truncation=True,\n",
    "    padding='max_length',\n",
    "    max_length=18,\n",
    "    return_tensors='pt',\n",
    "    return_length=True,\n",
    ")\n",
    "\n",
    "# Inspect the shape of every tensor in the encoding\n",
    "for k, v in out.items():\n",
    "    print(k, v.shape)\n",
    "\n",
    "# Decode the first pair's ids back into text\n",
    "print(token.decode(out['input_ids'][0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3b6bc44a",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at data/ChnSentiCorp/train/cache-9d6ffbe8cbc206e6.arrow\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(8001, '地理位置佳，在市中心。酒店服务好、早餐品', '大。人品不好，发现1个暗点，都没地哭去。', 1)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#第9章/定义数据集\n",
    "import torch\n",
    "from datasets import load_from_disk\n",
    "import random\n",
    "\n",
    "\n",
    "class Dataset(torch.utils.data.Dataset):\n",
    "    def __init__(self, split):\n",
    "        dataset = load_from_disk('./data/ChnSentiCorp')[split]\n",
    "\n",
    "        def f(data):\n",
    "            return len(data['text']) > 40\n",
    "\n",
    "        self.dataset = dataset.filter(f)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.dataset)\n",
    "\n",
    "    def __getitem__(self, i):\n",
    "        text = self.dataset[i]['text']\n",
    "\n",
    "        #切分一句话为前半句和后半句\n",
    "        sentence1 = text[:20]\n",
    "        sentence2 = text[20:40]\n",
    "\n",
    "        #随机整数，取值范围为0和1\n",
    "        label = random.randint(0, 1)\n",
    "\n",
    "        #有一半的概率把后半句替换为一句无关的话\n",
    "        if label == 1:\n",
    "            j = random.randint(0, len(self.dataset) - 1)\n",
    "            sentence2 = self.dataset[j]['text'][20:40]\n",
    "\n",
    "        return sentence1, sentence2, label\n",
    "\n",
    "\n",
    "# Build the dataset over the 'train' split; later cells read `dataset`\n",
    "dataset = Dataset('train')\n",
    "\n",
    "# Draw one example. Labels are sampled randomly inside __getitem__, so the\n",
    "# displayed pair and label can differ from run to run.\n",
    "sentence1, sentence2, label = dataset[7]\n",
    "\n",
    "len(dataset), sentence1, sentence2, label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a4205cea",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'cuda'"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#第9章/定义计算设备\n",
    "import torch\n",
    "\n",
    "device = 'cpu'\n",
    "if torch.cuda.is_available():\n",
    "    device = 'cuda'\n",
    "\n",
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "ae7c342b",
   "metadata": {},
   "outputs": [],
   "source": [
    "#第9章/数据整理函数\n",
    "def collate_fn(data):\n",
    "    sents = [i[:2] for i in data]\n",
    "    labels = [i[2] for i in data]\n",
    "\n",
    "    #编码\n",
    "    data = token.batch_encode_plus(batch_text_or_text_pairs=sents,\n",
    "                                   truncation=True,\n",
    "                                   padding='max_length',\n",
    "                                   max_length=45,\n",
    "                                   return_tensors='pt',\n",
    "                                   return_length=True,\n",
    "                                   add_special_tokens=True)\n",
    "\n",
    "    #input_ids:编码之后的数字\n",
    "    #attention_mask:是补零的位置是0,其他位置是1\n",
    "    #token_type_ids:第一个句子和特殊符号的位置是0,第二个句子的位置是1\n",
    "    input_ids = data['input_ids'].to(device)\n",
    "    attention_mask = data['attention_mask'].to(device)\n",
    "    token_type_ids = data['token_type_ids'].to(device)\n",
    "    labels = torch.LongTensor(labels).to(device)\n",
    "\n",
    "    return input_ids, attention_mask, token_type_ids, labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "f286f7f4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[CLS] 酒 店 还 是 非 常 的 不 错 ， 我 预 定 的 是 套 间 ， 服 务 [SEP] 非 常 好 ， 随 叫 随 到 ， 结 帐 非 常 快 。 [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(torch.Size([3, 45]),\n",
       " torch.Size([3, 45]),\n",
       " torch.Size([3, 45]),\n",
       " tensor([0, 0, 1], device='cuda:0'))"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Chapter 9 / trial run of the collate function\n",
    "# A hand-written batch of three (sentence1, sentence2, label) samples\n",
    "data = [('酒店还是非常的不错，我预定的是套间，服务', '非常好，随叫随到，结帐非常快。', 0),\n",
    "        ('外观很漂亮,性价比感觉还不错，功能简', '单,适合出差携带。蓝牙摄象头都有了。', 0),\n",
    "        ('《穆斯林的葬礼》我已闻名很久，只是一直没', '怎能享受4星的服务，连空调都不能用的。', 1)]\n",
    "\n",
    "# Trial run\n",
    "input_ids, attention_mask, token_type_ids, labels = collate_fn(data)\n",
    "\n",
    "# Decode the first sample's ids back into text\n",
    "print(token.decode(input_ids[0]))\n",
    "\n",
    "input_ids.shape, attention_mask.shape, token_type_ids.shape, labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "d3d5af04",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1000"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Chapter 9 / data loader\n",
    "# Batches of 8 samples drawn in shuffled order from `dataset`; collate_fn\n",
    "# tokenizes each batch, and drop_last=True discards a final incomplete batch.\n",
    "loader = torch.utils.data.DataLoader(dataset=dataset,\n",
    "                                     batch_size=8,\n",
    "                                     collate_fn=collate_fn,\n",
    "                                     shuffle=True,\n",
    "                                     drop_last=True)\n",
    "\n",
    "# Number of batches per pass over the data\n",
    "len(loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "27de7cfd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([8, 45]),\n",
       " torch.Size([8, 45]),\n",
       " torch.Size([8, 45]),\n",
       " tensor([0, 1, 0, 0, 1, 0, 0, 0], device='cuda:0'))"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Chapter 9 / look at one sample batch\n",
    "# Take only the first batch from the loader; the unpacked tensors stay in the\n",
    "# notebook namespace and are reused by the trial-run cells further down.\n",
    "for i, (input_ids, attention_mask, token_type_ids,\n",
    "        labels) in enumerate(loader):\n",
    "    break\n",
    "\n",
    "input_ids.shape, attention_mask.shape, token_type_ids.shape, labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "c558bfa4",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "10226.7648"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Chapter 9 / load the pretrained model\n",
    "from transformers import BertModel\n",
    "\n",
    "# Load the weights published under the 'bert-base-chinese' checkpoint name\n",
    "pretrained = BertModel.from_pretrained('bert-base-chinese')\n",
    "\n",
    "# Total parameter count, expressed in units of 10,000 parameters\n",
    "sum(i.numel() for i in pretrained.parameters()) / 10000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "2f9d0a31",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Chapter 9 / the pretrained model is not trained, so it needs no gradients\n",
    "# Switch off gradient tracking on every parameter of `pretrained`.\n",
    "for param in pretrained.parameters():\n",
    "    param.requires_grad_(False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "4ee2de49",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([8, 45, 768])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Chapter 9 / trial run of the pretrained model\n",
    "# Move the model to the selected compute device\n",
    "pretrained.to(device)\n",
    "\n",
    "# Trial forward pass on the sample batch unpacked in an earlier cell\n",
    "out = pretrained(input_ids=input_ids,\n",
    "                 attention_mask=attention_mask,\n",
    "                 token_type_ids=token_type_ids)\n",
    "\n",
    "# One hidden vector per token position: (batch, sequence length, hidden size)\n",
    "out.last_hidden_state.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "4a96638c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([8, 2])"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#第9章/定义下游任务模型\n",
    "class Model(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.fc = torch.nn.Linear(768, 2)\n",
    "\n",
    "    def forward(self, input_ids, attention_mask, token_type_ids):\n",
    "        #使用预训练模型抽取数据特征\n",
    "        with torch.no_grad():\n",
    "            out = pretrained(input_ids=input_ids,\n",
    "                             attention_mask=attention_mask,\n",
    "                             token_type_ids=token_type_ids)\n",
    "\n",
    "        #对抽取的特征只取第一个字的结果做分类即可\n",
    "        out = self.fc(out.last_hidden_state[:, 0])\n",
    "\n",
    "        out = out.softmax(dim=1)\n",
    "\n",
    "        return out\n",
    "\n",
    "\n",
    "# Instantiate the downstream classifier used by train() and test()\n",
    "model = Model()\n",
    "\n",
    "# Move the model to the selected compute device\n",
    "model.to(device)\n",
    "\n",
    "# Trial forward pass on the sample batch; only the output shape is displayed\n",
    "model(input_ids=input_ids,\n",
    "      attention_mask=attention_mask,\n",
    "      token_type_ids=token_type_ids).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "9b99dbd4",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0.681793749332428 4.995e-05 0.625\n",
      "20 0.6200282573699951 4.8950000000000004e-05 0.875\n",
      "40 0.608764111995697 4.795e-05 0.75\n",
      "60 0.6029558181762695 4.695e-05 0.75\n",
      "80 0.49943840503692627 4.5950000000000006e-05 1.0\n",
      "100 0.5959280729293823 4.495e-05 0.75\n",
      "120 0.5847670435905457 4.3950000000000004e-05 0.75\n",
      "140 0.5122112035751343 4.295e-05 1.0\n",
      "160 0.4979666471481323 4.195e-05 1.0\n",
      "180 0.5046210289001465 4.095e-05 0.875\n",
      "200 0.5604161024093628 3.995e-05 0.75\n",
      "220 0.4611266255378723 3.8950000000000005e-05 1.0\n",
      "240 0.44622868299484253 3.795e-05 1.0\n",
      "260 0.41288116574287415 3.6950000000000004e-05 1.0\n",
      "280 0.5807414650917053 3.595e-05 0.75\n",
      "300 0.43091052770614624 3.495e-05 1.0\n",
      "320 0.4966319799423218 3.3950000000000005e-05 0.875\n",
      "340 0.5417488813400269 3.295e-05 0.875\n",
      "360 0.44159558415412903 3.1950000000000004e-05 1.0\n",
      "380 0.4915757179260254 3.095e-05 0.75\n",
      "400 0.4184477627277374 2.995e-05 1.0\n",
      "420 0.5038822889328003 2.895e-05 0.875\n",
      "440 0.46791255474090576 2.7950000000000005e-05 0.875\n",
      "460 0.46281614899635315 2.6950000000000005e-05 0.875\n",
      "480 0.38066813349723816 2.595e-05 1.0\n",
      "500 0.5690062642097473 2.495e-05 0.75\n",
      "520 0.4103327691555023 2.395e-05 1.0\n",
      "540 0.5749650001525879 2.2950000000000002e-05 0.75\n",
      "560 0.6115520596504211 2.195e-05 0.75\n",
      "580 0.467792809009552 2.095e-05 0.75\n",
      "600 0.5293932557106018 1.995e-05 0.875\n",
      "620 0.49104464054107666 1.895e-05 0.875\n",
      "640 0.3909786641597748 1.795e-05 1.0\n",
      "660 0.4154558777809143 1.6950000000000002e-05 1.0\n",
      "680 0.5874401926994324 1.595e-05 0.625\n",
      "700 0.6164826154708862 1.4950000000000001e-05 0.75\n",
      "720 0.3789943754673004 1.3950000000000002e-05 1.0\n",
      "740 0.36395886540412903 1.2950000000000001e-05 1.0\n",
      "760 0.40311199426651 1.195e-05 1.0\n",
      "780 0.5107507109642029 1.095e-05 0.875\n",
      "800 0.4358420670032501 9.950000000000001e-06 0.875\n",
      "820 0.42680349946022034 8.95e-06 0.875\n",
      "840 0.4425410330295563 7.95e-06 1.0\n",
      "860 0.5414326786994934 6.950000000000001e-06 0.75\n",
      "880 0.5244377851486206 5.95e-06 0.875\n",
      "900 0.6650505661964417 4.950000000000001e-06 0.5\n",
      "920 0.3545813262462616 3.95e-06 1.0\n",
      "940 0.526192307472229 2.95e-06 0.75\n",
      "960 0.48736572265625 1.95e-06 0.875\n",
      "980 0.5858534574508667 9.5e-07 0.75\n"
     ]
    }
   ],
   "source": [
    "#第9章/训练\n",
    "from transformers import AdamW\n",
    "from transformers.optimization import get_scheduler\n",
    "\n",
    "\n",
    "def train():\n",
    "    \"\"\"Run one pass over `loader`, optimising `model` with AdamW.\n",
    "\n",
    "    The learning rate starts at 5e-5 and follows the 'linear' schedule with\n",
    "    no warm-up over len(loader) steps. Every 20th step prints the step\n",
    "    index, the batch loss, the current learning rate and the batch accuracy.\n",
    "    \"\"\"\n",
    "    # Loss function. Note: CrossEntropyLoss applies log-softmax internally.\n",
    "    criterion = torch.nn.CrossEntropyLoss()\n",
    "\n",
    "    # Optimizer over the parameters exposed by `model`\n",
    "    optimizer = AdamW(model.parameters(), lr=5e-5)\n",
    "\n",
    "    # Learning-rate scheduler, stepped once per batch\n",
    "    scheduler = get_scheduler(name='linear',\n",
    "                              optimizer=optimizer,\n",
    "                              num_warmup_steps=0,\n",
    "                              num_training_steps=len(loader))\n",
    "\n",
    "    # Put the model in training mode\n",
    "    model.train()\n",
    "\n",
    "    # Walk through the training data batch by batch\n",
    "    for step, batch in enumerate(loader):\n",
    "        input_ids, attention_mask, token_type_ids, labels = batch\n",
    "\n",
    "        # Forward pass\n",
    "        out = model(input_ids=input_ids,\n",
    "                    attention_mask=attention_mask,\n",
    "                    token_type_ids=token_type_ids)\n",
    "\n",
    "        # Compute the loss, back-propagate, update, then clear the gradients\n",
    "        loss = criterion(out, labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        scheduler.step()\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        # Only report progress on every 20th step\n",
    "        if step % 20 != 0:\n",
    "            continue\n",
    "\n",
    "        predictions = out.argmax(dim=1)\n",
    "        accuracy = (predictions == labels).sum().item() / len(labels)\n",
    "        lr = optimizer.param_groups[0]['lr']\n",
    "        print(step, loss.item(), lr, accuracy)\n",
    "\n",
    "\n",
    "# Run a single training pass over the training loader\n",
    "train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "ea603e6c",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at data/ChnSentiCorp/test/cache-0092adaf649d6010.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "1\n",
      "2\n",
      "3\n",
      "4\n",
      "0.89375\n"
     ]
    }
   ],
   "source": [
    "#第9章/测试\n",
    "def test():\n",
    "    \"\"\"Print the accuracy of `model` on five shuffled batches of the 'test' split.\n",
    "\n",
    "    The index of each evaluated batch is printed first, followed by the\n",
    "    fraction of correct predictions over all evaluated samples.\n",
    "    \"\"\"\n",
    "    # Loader over the test split: shuffled batches of 32, incomplete tail dropped\n",
    "    loader_test = torch.utils.data.DataLoader(dataset=Dataset('test'),\n",
    "                                              batch_size=32,\n",
    "                                              collate_fn=collate_fn,\n",
    "                                              shuffle=True,\n",
    "                                              drop_last=True)\n",
    "\n",
    "    # Put the downstream model in evaluation mode\n",
    "    model.eval()\n",
    "    hits, seen = 0, 0\n",
    "\n",
    "    # Walk through the test data batch by batch\n",
    "    for step, batch in enumerate(loader_test):\n",
    "\n",
    "        # Five batches are enough; there is no need to go through all of them\n",
    "        if step == 5:\n",
    "            break\n",
    "\n",
    "        print(step)\n",
    "\n",
    "        input_ids, attention_mask, token_type_ids, labels = batch\n",
    "\n",
    "        # Forward pass without gradient tracking\n",
    "        with torch.no_grad():\n",
    "            scores = model(input_ids=input_ids,\n",
    "                           attention_mask=attention_mask,\n",
    "                           token_type_ids=token_type_ids)\n",
    "\n",
    "        predictions = scores.argmax(dim=1)\n",
    "\n",
    "        # Accumulate the number of correct predictions and of samples\n",
    "        hits += (predictions == labels).sum().item()\n",
    "        seen += len(labels)\n",
    "\n",
    "    print(hits / seen)\n",
    "\n",
    "\n",
    "# Evaluate the trained model on a sample of the test split\n",
    "test()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
